#!/usr/bin/env python3

# Python script to run and analyse MMS test

from __future__ import division
from __future__ import print_function
from builtins import str

from boututils.run_wrapper import shell, launch_safe
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, concatenate



# --- Build the test executable ---
print("Making MMS test")
shell("make > make.log")

# --- Test configuration ---

# Directory passed to the executable with -d; the dump files and run logs
# are also read from / written to this directory
path = "mms-slab2d"

# Number of MPI processes used for each run
nproc = 4

# List of NX values to scan over
nxlist = [16, 32, 64, 128, 256]

# Nx * dx held constant across the scan
nxdx = 128*2e-5

# Variables whose error fields are analysed
varlist = ["Ne", "Te", "Vort"]

# Overall result; set to False if any variable fails the convergence check
success = True

# Error norms: one list per variable, one entry appended per resolution
error_2 = {name: [] for name in varlist}    # The L2 error (RMS)
error_inf = {name: [] for name in varlist}  # The maximum error

# Run the simulation once per resolution and record the error norms
for nx in nxlist:
    # mesh:nx is nx+2 (presumably one boundary/guard cell on each side --
    # the collected data is trimmed with 1:-1 in X below); mesh:dx is scaled
    # so that nx*dx stays equal to nxdx.
    args = "-d " + path + " mesh:nx="+str(nx+2)+" mesh:dx="+str(nxdx/nx)+" MZ="+str(nx)

    print("Running with " + args)

    # Delete old data so that stale dump files are never collected
    shell("rm %s/BOUT.dmp.*.nc" % (path,))

    # Command to run
    cmd = "./gbs "+args
    # Launch using MPI
    s, out = launch_safe(cmd, nproc=nproc, pipe=True)

    # Save output to log file; the context manager closes the file even if
    # the write fails
    with open(path+"/run.log."+str(nx), "w") as f:
        f.write(out)

    for var in varlist:
        # Collect the "E_<var>" field at time index 1
        E = collect("E_"+var, tind=[1,1], path=path)
        # Keep the single collected time point, drop the first and last X
        # points, and take index 0 of the third dimension
        E = E[0,1:-1,0,:]

        l2 = sqrt(mean(E**2))   # RMS of the error
        linf = max(abs(E))      # Largest absolute error

        error_2[var].append( l2 )
        error_inf[var].append( linf )

        print("%s : l-2 %f l-inf %f" % (var, l2, linf))

# Calculate (normalised) grid spacing.
# nxlist holds the number of interior points: each run uses mesh:nx = nx+2
# and mesh:dx = nxdx/nx, so the spacing is proportional to 1/nx. Only ratios
# of dx are used below, so the constant factor nxdx is dropped.
dx = 1. / array(nxlist)

# Check the convergence order of the L2 error between the two finest
# resolutions; the test passes only if every variable is close to 2nd order.
for var in varlist:
    order = log(error_2[var][-1] / error_2[var][-2]) / log(dx[-1] / dx[-2])
    print("Convergence order = %f" % (order))
    # Written as a negated range test so that a NaN order also fails
    if not (1.9 < order < 2.1):
        success = False


# Plot the error norms against mesh spacing. Plotting is best-effort: a
# failure here is reported but does not affect the test result.
try:
    import matplotlib.pyplot as plt

    for var in varlist:
        plt.plot(dx, error_2[var], '-o', label=r'$l^2$ '+var)
        plt.plot(dx, error_inf[var], '-x', label=r'$l^\infty$ '+var)

        # Convergence order estimated from the two finest resolutions
        order = log(error_2[var][-1] / error_2[var][-2]) / log(dx[-1] / dx[-2])

        # Reference line with that slope, anchored at the finest resolution
        plt.plot(dx, error_2[var][-1]*(dx/dx[-1])**order, '--', label="Order %.1f"%(order))

    plt.legend(loc="upper left")
    plt.grid()

    plt.yscale('log')
    plt.xscale('log')

    plt.xlabel(r'Mesh spacing $\delta x$')
    plt.ylabel("Error norm")

    plt.savefig("norm-slab2d.pdf")

    #plt.show()
except ImportError:
    # matplotlib is optional
    print("matplotlib not available; skipping plot")
except Exception as err:
    # Report the problem instead of hiding it; unlike a bare except this
    # lets KeyboardInterrupt and SystemExit propagate.
    print("Warning: plotting failed: %s" % (err,))

# Report the result through the process exit status: 0 on success, 1 on
# failure. SystemExit is raised directly because the exit() helper is
# provided by the site module for interactive use and is not guaranteed to
# exist in every interpreter invocation.
raise SystemExit(0 if success else 1)
